import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from Network.network import Network
from Network.network_utils import reduce_function
from Network.General.Conv.conv import ConvNetwork
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Factor.factored import FactoredNetwork, return_values


class PointNetwork(FactoredNetwork):
    """Factored network that embeds every query with a shared ConvNetwork.

    When ``args.factor.aggregate_final`` is set, the per-query embeddings are
    combined with ``reduce_function(self.fp.reduce_fn, ...)`` and passed
    through a final ``MLPNetwork``; otherwise only the conv embeddings are
    produced.
    """

    def __init__(self, args):
        super().__init__(args)
        # assumes the input is flattened list of input space sized values
        # needs an object dim
        self.fp = args.factor_params

        self.conv = ConvNetwork(args)
        if args.factor.aggregate_final:
            final_args = copy.deepcopy(args)
            # 'cat' keeps one embedding per query side by side instead of
            # reducing them, so the MLP input width scales with num_queries.
            if self.fp.reduce_fn != 'cat':
                final_args.num_inputs = self.fp.output_dim
            else:
                final_args.num_inputs = self.fp.output_dim * self.fp.num_queries
            final_args.hidden_sizes = args.factor.final_layers
            # Width the reduced embedding is reshaped to in forward(); it must
            # match the MLP's num_inputs exactly.
            self._reduction_dim = final_args.num_inputs
            # Built like ConvNetwork above: the args object is passed whole.
            self.MLP = MLPNetwork(final_args)
            self.model = nn.Sequential(self.conv, self.MLP)
        else:
            self.model = nn.ModuleList([self.conv])
        self.train()
        self.reset_network_parameters()

    def forward(self, key, query, mask, ret_settings):
        """Embed each query and optionally aggregate the embeddings.

        Args:
            key: not used by this method.
            query: presumably (batch, num_queries, dim) -- TODO confirm; axes
                1 and 2 are swapped before the conv and swapped back after.
            mask: optional; ``mask[:, 0].unsqueeze(-1)`` is multiplied into
                the embeddings.
            ret_settings: forwarded unchanged to ``return_values``.

        Returns:
            ``return_values(ret_settings, x, embeddings, reduction)``. Without
            final aggregation, ``x`` is the embeddings and ``reduction`` is
            ``None``.
        """
        # presumably ConvNetwork wants (batch, channels, length) -- verify
        embeddings = self.conv(query.transpose(1, 2)).transpose(2, 1)
        if mask is not None:
            embeddings = embeddings * mask[:, 0].unsqueeze(-1)
        if self.aggregate_final:
            reduced = reduce_function(self.fp.reduce_fn, embeddings)[0]
            # reshape (not view): the reduced tensor may be non-contiguous
            reduction = reduced.reshape(-1, self._reduction_dim)
            x = self.MLP(reduction)
        else:
            # Without this branch `reduction` would be unbound at the return
            # and `x` would still be the raw (transposed) query.
            reduction = None
            x = embeddings
        return return_values(ret_settings, x, embeddings, reduction)